Introduction to generative function combinators

Author

McCoy Reynolds Becker

Published

December 1, 2022

Abstract
This notebook introduces generative function combinators - higher-order functions which accept generative functions and return new generative functions, typically transforming the generative structure to encode a pattern of repeated computation. We focus on vector-shaped combinators - combinators which transform input generative structure into (mapped or scanned) vectorial structure.
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=80)

# Reproducibility.
key = jax.random.PRNGKey(314159)

Gen helps probabilistic programmers design and implement models and inference algorithms by automating the (often) complicated inference math. The generative function interface is the key abstraction layer which provides this automation. Generative function language designers can extend the interface to new generative function objects - providing domain-specific patterns and optimizations which users can automatically take advantage of.

One key class of generative function languages is combinators - higher-order functions which accept generative functions as input and produce a new generative function type as output.

Combinators functionally transform the generative structure that we pass into them, expressing useful patterns - including chain-like computations, IID sampling patterns, or generative computations which form grammar-like structures.
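As a loose analogy in plain Python (not the GenJAX API), a combinator is just a higher-order function over sampling procedures - here, a toy map-like combinator which encodes an IID sampling pattern:

```python
import random

# A toy "generative function": a procedure which samples a value.
def flip():
    return random.random() < 0.5

# A toy map-like combinator: a higher-order function which wraps a
# sampler into one that draws n IID samples.
def map_combinator(gen_fn, n):
    def mapped():
        return [gen_fn() for _ in range(n)]
    return mapped

five_flips = map_combinator(flip, 5)
samples = five_flips()  # a list of 5 booleans
```

Real combinators do more than compose samplers - they also transform the trace and choice map structure, and specialize the interface methods - but the higher-order shape is the same.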

Combinators also expose optimization opportunities - by registering the patterns as generative functions, implementors (e.g. the library authors) can specialize the implementation of the generative function interface methods. Users of combinators can then take advantage of this interface specialization to express asymptotically optimal updates (useful in e.g. MCMC kernels), or optimized importance weight calculations.

In this notebook, we’ll be discussing Unfold - a combinator for expressing generative computations which are reminiscent of state-space (or Markov) models. To keep things simple, we’ll explore a hidden Markov model example - but combinator usage generalizes to models with much richer structure.

Introducing Unfold

Let’s discuss Unfold.1

  • 1 A quick reminder: when in doubt, you can use the console from console = genjax.pretty() to inspect the classes which we discuss in the notebooks.

  • console.help(genjax.UnfoldCombinator)
    ╭─ <class 'genjax.generative_functions.combinators.unfold.UnfoldCombinator'> ──╮
     class UnfoldCombinator(kernel: genjax.core.datatypes.GenerativeFunction, max 
                                                                                  
     :code:`UnfoldCombinator` accepts a single kernel generative function         
     as input and a static unroll length which specifies how many iterations      
     to run the chain for.                                                        
                                                                                  
     A kernel generative function is one which accepts and returns                
     the same signature of arguments. Under the hood, :code:`UnfoldCombinator`    
     is implemented using :code:`jax.lax.scan` - which has the same               
     requirements.                                                                
                                                                                  
     Parameters                                                                   
     ----------                                                                   
                                                                                  
     gen_fn: :code:`GenerativeFunction`                                           
         A single *kernel* `GenerativeFunction` instance.                         
                                                                                  
     length: :code:`Int`                                                          
         An integer specifying the unroll length of the chain of applications.    
                                                                                  
     Returns                                                                      
     -------                                                                      
     :code:`UnfoldCombinator`                                                     
         A single :code:`UnfoldCombinator` generative function which              
         implements the generative function interface using a scan-like           
         pattern. This generative function will perform a dependent-for           
         iteration (passing the return value of generative function application)  
         to the next iteration for :code:`length` number of steps.                
         The programmer must provide an initial value to start the chain of       
         iterations off.                                                          
                                                                                  
     Example                                                                      
     -------                                                                      
                                                                                  
     .. jupyter-execute::                                                         
                                                                                  
         import jax                                                               
         import genjax                                                            
                                                                                  
                                                                                  
         @genjax.gen                                                              
         def random_walk(key, prev):                                              
             key, x = genjax.trace("x", genjax.Normal)(key, prev, 1.0)            
             return (key, x)                                                      
                                                                                  
                                                                                  
         unfold = genjax.UnfoldCombinator(random_walk, 1000)                      
         init = 0.5                                                               
         key = jax.random.PRNGKey(314159)                                         
         key, tr = jax.jit(genjax.simulate(unfold))(key, (999, init))             
         print(tr)                                                                
                                                                                  
             assess = def assess(self, key, chm, args):                           
            flatten = def flatten(self):                                          
     get_trace_type = def get_trace_type(self, key, args, **kwargs):              
         importance = def importance(self, key, chm, args):                       
           simulate = def simulate(self, key, args):                              
          unflatten = def unflatten(data, xs):                                    
              unzip = def unzip(self, key: jaxtyping.Integer[Array, '...'],       
                      fixed: genjax.core.datatypes.ChoiceMap) ->                  
                      Tuple[jaxtyping.Integer[Array, '...'],
                      Callable[[genjax.core.datatypes.ChoiceMap, Tuple], float],  
                      Callable[[genjax.core.datatypes.ChoiceMap, Tuple], Any]]:   
             update = def update(self, key, prev, chm, diffs):                    
    ╰──────────────────────────────────────────────────────────────────────────────╯
    

    By inspecting the class, we gain some confidence that it indeed implements the generative function interface. Below, we show a TikZ picture which illustrates the pattern which Unfold captures.
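Since the docstring notes that UnfoldCombinator is built on jax.lax.scan, the dependent-for pattern it captures can be sketched in plain JAX - here with a deterministic stand-in for the kernel, not the GenJAX API:

```python
import jax
import jax.numpy as jnp

# A deterministic stand-in for a kernel: state in, state out.
def kernel(prev, _):
    new = prev * 0.5 + 1.0
    return new, new  # (carry for the next step, per-step output)

# scan threads the carry through 10 dependent iterations,
# stacking each per-step output into a length-10 array.
final, ys = jax.lax.scan(kernel, 0.0, None, length=10)
```

Unfold applies its kernel generative function in exactly this carry-threading fashion, while additionally managing the traced random choices at each step.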

    Using Unfold

    How do we make an instance of Unfold? Given an existing generative function which is a kernel - a generative function which accepts and returns the same type signature - we can create a valid Unfold instance.2

  • 2 This is not strictly true. Unfold also allows you to pass in a set of static arguments which are provided to the kernel after the state argument, unchanged, at each time step. We show this at the bottom of the notebook.

  • Here’s an example kernel:

    @genjax.gen
    def kernel(key, prev_latent):
        key, new_latent = genjax.trace("z", genjax.Normal)(key, prev_latent, 1.0)
        key, new_obs = genjax.trace("x", genjax.Normal)(key, new_latent, 1.0)
        return key, new_latent
    key, tr = jax.jit(kernel.simulate)(key, (0.3,))
    tr
    
    
    BuiltinTrace
    ├── gen_fn
    │   └── BuiltinGenerativeFunction
    │       └── source
    │           └── <function kernel>
    ├── args
    │   └── tuple
    │       └── f32[]
    ├── retval
    │   └── f32[]
    ├── choices
    │   └── BuiltinChoiceMap
    │       ├── z
    │       │   └── DistributionTrace
    │       │       ├── gen_fn
    │       │       │   └── _Normal
    │       │       ├── args
    │       │       │   └── tuple
    │       │       │       ├── f32[]
    │       │       │       └── f32[]
    │       │       ├── value
    │       │       │   └── ValueChoiceMap
    │       │       │       └── value
    │       │       │           └── f32[]
    │       │       └── score
    │       │           └── f32[]
    │       └── x
    │           └── DistributionTrace
    │               ├── gen_fn
    │               │   └── _Normal
    │               ├── args
    │               │   └── tuple
    │               │       ├── f32[]
    │               │       └── f32[]
    │               ├── value
    │               │   └── ValueChoiceMap
    │               │       └── value
    │               │           └── f32[]
    │               └── score
    │                   └── f32[]
    ├── cache
    │   └── BuiltinTrie
    └── score
        └── f32[]
    

    To create an Unfold instance, we provide two things:

    • The kernel generative function.
    • A static maximum unroll chain argument. Dynamically, Unfold may not unroll all the way up to this maximum - but for JAX/XLA compilation, we need to provide this maximum value as an invariant upper bound for any invocation of Unfold.
    chain = genjax.Unfold(kernel, max_length=10)
    chain
    
    
    UnfoldCombinator
    ├── kernel
    │   └── BuiltinGenerativeFunction
    │       └── source
    │           └── <function kernel>
    └── max_length
        └── (lit) 10
    

    To invoke an interface method, Unfold expects a tuple of arguments whose first element is the maximum index of the resulting chain, and whose second element is the initial state.

    Usage of index argument vs. a length argument

    Note the emphasis on index above - think of the index value as an upper bound on the active indices of the resulting chain. An active index is one at which the value was evolved by the kernel from the previous value. Passing in index = 5 means: all values after return[5] are not evolved - they are simply filled with the return[5] value.

    Indexing follows Python convention - so e.g. passing in 0 as the index means that a single application of the kernel is applied to the initial state, after which evolution halts and the final value is carried forward unchanged.
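A simplified model of these active-index semantics can be written as a masked scan in plain JAX (an illustration, not Unfold's actual implementation):

```python
import jax
import jax.numpy as jnp

MAX_LENGTH = 10  # static upper bound, fixed at compile time

@jax.jit
def unroll(init, index):
    # index is dynamic - the same compiled code serves any index value.
    def step(carry, t):
        new = jnp.where(t <= index, carry + 1.0, carry)
        return new, new
    _, ys = jax.lax.scan(step, init, jnp.arange(MAX_LENGTH))
    return ys

ys = unroll(0.0, 5)
# Steps 0 through 5 evolve the state; later slots hold the step-5 value.
```

Because the scan always runs for MAX_LENGTH steps and only the mask depends on index, XLA compiles the program once and reuses it for any index value.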

    key, tr = jax.jit(chain.simulate)(key, (5, 0.3))
    tr
    
    
    UnfoldTrace
    ├── gen_fn
    │   └── UnfoldCombinator
    │       ├── kernel
    │       │   └── BuiltinGenerativeFunction
    │       │       └── source
    │       │           └── <function kernel>
    │       └── max_length
    │           └── (lit) 10
    ├── indices
    │   └── i32[10]
    ├── inner
    │   └── BuiltinTrace
    │       ├── gen_fn
    │       │   └── BuiltinGenerativeFunction
    │       │       └── source
    │       │           └── <function kernel>
    │       ├── args
    │       │   └── tuple
    │       │       └── f32[10]
    │       ├── retval
    │       │   └── f32[10]
    │       ├── choices
    │       │   └── BuiltinChoiceMap
    │       │       ├── z
    │       │       │   └── DistributionTrace
    │       │       │       ├── gen_fn
    │       │       │       │   └── _Normal
    │       │       │       ├── args
    │       │       │       │   └── tuple
    │       │       │       │       ├── f32[10]
    │       │       │       │       └── f32[10]
    │       │       │       ├── value
    │       │       │       │   └── ValueChoiceMap
    │       │       │       │       └── value
    │       │       │       │           └── f32[10]
    │       │       │       └── score
    │       │       │           └── f32[10]
    │       │       └── x
    │       │           └── DistributionTrace
    │       │               ├── gen_fn
    │       │               │   └── _Normal
    │       │               ├── args
    │       │               │   └── tuple
    │       │               │       ├── f32[10]
    │       │               │       └── f32[10]
    │       │               ├── value
    │       │               │   └── ValueChoiceMap
    │       │               │       └── value
    │       │               │           └── f32[10]
    │       │               └── score
    │       │                   └── f32[10]
    │       ├── cache
    │       │   └── BuiltinTrie
    │       └── score
    │           └── f32[10]
    ├── args
    │   └── tuple
    │       ├── i32[]
    │       └── f32[]
    ├── retval
    │   └── f32[10]
    └── score
        └── f32[]
    
    To inspect this index bookkeeping directly, we can build a trace whose recorded values are zeroed out, using the internal _make_zero_trace method:

    tr = chain._make_zero_trace(key, (5, 0.3))
    tr
    
    
    UnfoldTrace
    ├── gen_fn
    │   └── UnfoldCombinator
    │       ├── kernel
    │       │   └── BuiltinGenerativeFunction
    │       │       └── source
    │       │           └── <function kernel>
    │       └── max_length
    │           └── (lit) 10
    ├── indices
    │   └── i32[10]
    ├── inner
    │   └── BuiltinTrace
    │       ├── gen_fn
    │       │   └── BuiltinGenerativeFunction
    │       │       └── source
    │       │           └── <function kernel>
    │       ├── args
    │       │   └── tuple
    │       │       └── f32[10]
    │       ├── retval
    │       │   └── f32[10]
    │       ├── choices
    │       │   └── BuiltinChoiceMap
    │       │       ├── z
    │       │       │   └── DistributionTrace
    │       │       │       ├── gen_fn
    │       │       │       │   └── _Normal
    │       │       │       ├── args
    │       │       │       │   └── tuple
    │       │       │       │       ├── f32[10]
    │       │       │       │       └── f32[10]
    │       │       │       ├── value
    │       │       │       │   └── ValueChoiceMap
    │       │       │       │       └── value
    │       │       │       │           └── f32[10]
    │       │       │       └── score
    │       │       │           └── f32[10]
    │       │       └── x
    │       │           └── DistributionTrace
    │       │               ├── gen_fn
    │       │               │   └── _Normal
    │       │               ├── args
    │       │               │   └── tuple
    │       │               │       ├── f32[10]
    │       │               │       └── f32[10]
    │       │               ├── value
    │       │               │   └── ValueChoiceMap
    │       │               │       └── value
    │       │               │           └── f32[10]
    │       │               └── score
    │       │                   └── f32[10]
    │       ├── cache
    │       │   └── BuiltinTrie
    │       └── score
    │           └── f32[10]
    ├── args
    │   └── tuple
    │       ├── i32[]
    │       └── f32[]
    ├── retval
    │   └── f32[10]
    └── score
        └── f32[]
    
    tr.indices
    Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
    
    tr.get_retval()
    Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
    

    Note how tr.indices keeps track of where the chain stopped evolving, according to the index argument to Unfold. In tr.get_retval(), we see that the final dynamically evolved value occurs at index = 5 (afterwards, evolution stops).

    Combinator choice maps

    Typically, each combinator has a unique choice map. The choice map simultaneously represents the structure of the generative choices which the transformed combinator generative function makes, as well as optimization opportunities which a user can take advantage of.

    Let’s study the choice map for UnfoldTrace.

    chm = tr.get_choices()
    chm
    
    
    VectorChoiceMap
    ├── indices
    │   └── i32[10]
    └── inner
        └── BuiltinTrace
            ├── gen_fn
            │   └── BuiltinGenerativeFunction
            │       └── source
            │           └── <function kernel>
            ├── args
            │   └── tuple
            │       └── f32[10]
            ├── retval
            │   └── f32[10]
            ├── choices
            │   └── BuiltinChoiceMap
            │       ├── z
            │       │   └── DistributionTrace
            │       │       ├── gen_fn
            │       │       │   └── _Normal
            │       │       ├── args
            │       │       │   └── tuple
            │       │       │       ├── f32[10]
            │       │       │       └── f32[10]
            │       │       ├── value
            │       │       │   └── ValueChoiceMap
            │       │       │       └── value
            │       │       │           └── f32[10]
            │       │       └── score
            │       │           └── f32[10]
            │       └── x
            │           └── DistributionTrace
            │               ├── gen_fn
            │               │   └── _Normal
            │               ├── args
            │               │   └── tuple
            │               │       ├── f32[10]
            │               │       └── f32[10]
            │               ├── value
            │               │   └── ValueChoiceMap
            │               │       └── value
            │               │           └── f32[10]
            │               └── score
            │                   └── f32[10]
            ├── cache
            │   └── BuiltinTrie
            └── score
                └── f32[10]
    

    Again, let’s look at the indices.

    chm.indices
    Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
    

    No surprises - the choice map also keeps track of which indices are active and which are inactive.

    Inactive indices do not participate in inference metadata computations - so e.g. if we ask for the score of the trace:

    tr.get_score()
    Array(0., dtype=float32)
    

    The score is the same as the sum of sub-trace scores from [0:5].

    np.sum(
        tr.get_subtree("z").get_score()[0:5] + tr.get_subtree("x").get_score()[0:5]
    )
    Array(0., dtype=float32)
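In plain jax.numpy terms, masking inactive indices out of the score can be sketched as follows (with made-up per-step scores - a model of the behavior, not GenJAX's internal implementation):

```python
import jax.numpy as jnp

# Pretend per-step log-scores for a chain with max_length = 10.
step_scores = jnp.arange(10, dtype=jnp.float32)
index = 5

# Only active steps (0 through index) contribute to the total score.
active = jnp.arange(10) <= index
total = jnp.sum(jnp.where(active, step_scores, 0.0))
```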
    

    The reason why we have an index argument is that we can dynamically choose how much of the chain contributes to the generative computation. This index argument can come from another generative function - it need not be a static value at JAX trace time.

    With this in mind, it’s best to think of Unfold as representing a space of processes which unroll up to some maximum static length - but the active generative process can halt before that maximum length.